# coding:utf-8
from numpy.ma import array

from src.util.picture_util import PictureUtil
from src.manager.oracle_manager import OracleManager
from src.manager.log_manager import LogManager
from src.config.oracle_config import OracleConfig

# Module-level logger used by ClosePriceMa5GoldCrossPictureHandler below.
Logger = LogManager.get_logger(__name__)


class ClosePriceMa5GoldCrossPictureHandler:
    """
    Build scatter plots about close-price / MA5 golden-cross records.

    Every query reads MDL_CLOSE_PRICE_MA5_GOLD_CROSS (alias t) joined with
    stock_transaction_data (alias t1) on the stock code and the buy date.
    """

    # FROM/JOIN clause shared by every query in this class.
    _FROM_JOIN = ("from MDL_CLOSE_PRICE_MA5_GOLD_CROSS t "
                  "join stock_transaction_data t1 on (t1.code_=t.stock_code and t1.date_=t.buy_date)")

    def __init__(self):
        """
        Create the OracleManager from OracleConfig, connect, and keep a cursor.
        """
        self.oracle_manager = OracleManager(OracleConfig.Username, OracleConfig.Password, OracleConfig.Url)
        self.oracle_manager.connect()
        self.cursor = self.oracle_manager.get_cursor()

    def _fetch_rows(self, sql):
        """
        Execute ``sql`` and return all rows, always closing the cursor and connection afterwards.

        The connection is released after each query. If a previous query already
        released it, it is reopened first so one instance can run several queries.

        :param sql: query text to execute (built only from constants in this class).
        :return: the rows returned by ``cursor.fetchall()``.
        """
        if self.cursor is None:
            # NOTE(review): assumes OracleManager.connect() may be called again
            # after connect_close() -- confirm in oracle_manager.
            self.oracle_manager.connect()
            self.cursor = self.oracle_manager.get_cursor()
        try:
            self.cursor.execute(sql)
            return self.cursor.fetchall()
        finally:
            # Release the resources even when execute/fetchall raises.
            self.cursor = None
            self.oracle_manager.cursor_close()
            self.oracle_manager.connect_close()

    def _fetch_average_and_profit_loss(self, first_column, second_column):
        """
        Return the mean of two stock_transaction_data columns and the matching profit_loss.

        Rows where either column or profit_loss is NULL are skipped.

        :param first_column: first column of stock_transaction_data to average (internal constant).
        :param second_column: second column of stock_transaction_data to average (internal constant).
        :return: (average_list, profit_loss_list), two lists of equal length.
        """
        # Optional filter used during analysis: "where t.profit_loss<=200"
        rows = self._fetch_rows(
            "select t.profit_loss, t1.%s, t1.%s %s" % (first_column, second_column, self._FROM_JOIN))
        average_list = []
        profit_loss_list = []
        for profit_loss, first_value, second_value in rows:
            if first_value is None or second_value is None or profit_loss is None:
                continue
            average_list.append((first_value + second_value) / 2)
            profit_loss_list.append(profit_loss)
        return average_list, profit_loss_list

    def get_k_d_average_profit_loss_data(self):
        """
        Get x/y data where x is the mean of k and d, and y is profit_loss.

        :return: (k_d_average_list, profit_loss_list)
        """
        Logger.info('获取x轴数据和y轴数据，其中x轴为k、d的平均值')
        return self._fetch_average_and_profit_loss('k', 'd')

    def get_dif_dea_average_profit_loss_data(self):
        """
        Get x/y data where x is the mean of dif and dea, and y is profit_loss.

        :return: (dif_dea_average_list, profit_loss_list)
        """
        Logger.info('获取x轴数据和y轴数据，其中x轴为dif和dea的平均值')
        return self._fetch_average_and_profit_loss('dif', 'dea')

    def get_k_d_dea_average_and_dif_dea_average_and_label_data(self):
        """
        Get the mean of k/d, the mean of dif/dea, and a profit label per record.

        The label is 1 when profit_loss > 0 and -1 when profit_loss < 0.
        Otherwise it is 0, which includes a NULL profit_loss because the CASE
        falls through to ELSE. Rows with a NULL k, d, dif or dea are skipped.

        :return: (k_d_average_list, dif_dea_average_list, label_list)
        """
        Logger.info('获取x轴数据、y轴数据和标签，其中x轴为k、d的平均值，y轴为dif和dea的平均值')
        rows = self._fetch_rows(
            "select t1.k, t1.d, t1.dif, t1.dea, "
            "case when t.profit_loss>0 then 1 when t.profit_loss<0 then -1 else 0 end "
            + self._FROM_JOIN)
        k_d_average_list = []
        dif_dea_average_list = []
        label_list = []
        for k, d, dif, dea, label in rows:
            if k is None or d is None or dif is None or dea is None:
                continue
            k_d_average_list.append((k + d) / 2)
            dif_dea_average_list.append((dif + dea) / 2)
            label_list.append(label)
        return k_d_average_list, dif_dea_average_list, label_list

    def create_close_price_ma5_gold_cross_refer_k_d_scatter_picture(self):
        """
        Draw a scatter plot: x is the mean of k and d, y is profit_loss of mdl_close_price_ma5_gold_cross.
        """
        Logger.info('生成散点图：x轴为k、d的平均值，y轴为mdl_close_price_ma5_gold_cross表的profit_loss字段')
        x_data_list, y_data_list = self.get_k_d_average_profit_loss_data()
        PictureUtil.scatter(x_data_list, y_data_list)

    def create_close_price_ma5_gold_cross_refer_dif_dea_scatter_picture(self):
        """
        Draw a scatter plot: x is the mean of dif and dea, y is profit_loss of mdl_close_price_ma5_gold_cross.
        """
        Logger.info('生成散点图：x轴为dif和dea的平均值，y轴为mdl_close_price_ma5_gold_cross表的profit_loss字段')
        x_data_list, y_data_list = self.get_dif_dea_average_profit_loss_data()
        PictureUtil.scatter(x_data_list, y_data_list)

    def create_close_price_ma5_gold_cross_refer_k_d_and_dif_dea_scatter_picture(self):
        """
        Draw a scatter plot: x is the mean of k and d, y is the mean of dif and dea.

        Points are sized and coloured by the profit label (-1, 0 or 1).
        """
        Logger.info('生成散点图：x轴为k、d的平均值，y轴为dif和dea的平均值，圆点颜色由profit_loss的正负决定')
        x_data_list, y_data_list, label_list = self.get_k_d_dea_average_and_dif_dea_average_and_label_data()
        # The s and c arguments must not be 0, so shift labels -1/0/1 to 1/2/3.
        shifted_labels = array(label_list) + 2
        PictureUtil.scatter(x_data_list, y_data_list, 1.0 * shifted_labels, 15.0 * shifted_labels)
